
[kv_offload+HMA][7/N]: Support register_kv_caches for hybrid models#37853

Merged
orozery merged 5 commits into vllm-project:main from orozery:kv-offload-register-hybrid-kv-caches
Mar 27, 2026

Conversation

@orozery
Collaborator

@orozery orozery commented Mar 23, 2026

This PR extends the offloading connector register_kv_caches function to support KV caches used in hybrid models.

We define a new CanonicalKVCaches class which captures:

  1. The unique set of KV cache tensors (tensors may be shared by multiple layers).
  2. A mapping from each group to its relevant KV cache data (given by a tensor pointer + page size). The canonical tensors each have dtype int8 and shape (num_blocks, page_size).

This PR also splits the offloading connector unit tests into multiple files.
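The canonical shape described above can be made concrete with a small sketch. This is an illustration only: `page_size_bytes` and its parameters are hypothetical names introduced here, not vLLM's actual API; the arithmetic simply shows what the second dimension of a canonical int8 tensor would be for a plain full-attention layer.

```python
# Hypothetical illustration (names are assumptions, not vLLM's API):
# the per-block page size in bytes of a full-attention layer's KV cache,
# which becomes the second dimension of the canonical int8 tensor
# of shape (num_blocks, page_size).

def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype_size: int, kv_factor: int = 2) -> int:
    """Bytes occupied by one block of a layer's KV cache.

    kv_factor is 2 for separate K and V tensors (e.g. flash-attention)
    and 1 for layouts that store a single tensor per block (e.g. MLA).
    """
    return kv_factor * block_size * num_kv_heads * head_size * dtype_size

# e.g. fp16 (2 bytes), 16-token blocks, 8 KV heads, head size 128:
size = page_size_bytes(block_size=16, num_kv_heads=8, head_size=128,
                       dtype_size=2)
print(size)  # 65536 bytes per block
```

Reinterpreting each layer's typed tensor as raw int8 bytes of this width is what lets one canonical representation cover differently shaped caches.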

@mergify mergify Bot added the v1 label Mar 23, 2026
@mergify
Contributor

mergify Bot commented Mar 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a significant and well-executed refactoring to support KV cache offloading for hybrid models. The introduction of CanonicalKVCaches provides a clean abstraction over different KV cache layouts, improving modularity and simplifying the transfer handler logic. The accompanying test refactoring is also a good improvement.

I have one high-severity comment regarding an inconsistency in the data type of the canonicalized tensor, which could lead to maintenance issues in the future. Addressing this would make the new abstraction more robust and easier to reason about.

Comment thread vllm/distributed/kv_transfer/kv_connector/v1/offloading/worker.py
@orozery orozery force-pushed the kv-offload-register-hybrid-kv-caches branch from 4b710fe to cc8dc31 Compare March 23, 2026 06:04
@mergify mergify Bot removed the needs-rebase label Mar 23, 2026
@mergify
Contributor

mergify Bot commented Mar 24, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @orozery.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 24, 2026
@orozery orozery force-pushed the kv-offload-register-hybrid-kv-caches branch from cc8dc31 to 943e72a Compare March 24, 2026 09:01
@mergify mergify Bot removed the needs-rebase label Mar 24, 2026
Collaborator

@NickLucche NickLucche left a comment


hey @orozery, I left only a few comments here, mainly because I can understand what you're doing, but I'm afraid I don't fully understand why you're doing it.

Could you elaborate on why we need a CanonicalKVCacheTensor here, and how that abstraction makes life easier for you in code/transfer?

Comment on lines +113 to +119
test_shape = attn_backends[layer_name].get_kv_cache_shape(
num_blocks=1234,
block_size=16,
num_kv_heads=1,
head_size=256,
)
num_blocks_logical_dim = test_shape.index(1234)
Collaborator


there's a get_kv_cache_block_dim API now we can use

Collaborator Author


Actually I see that I use test_shape a bit below for asserting (2, ...) for flash attention.

Comment on lines +196 to +203
page_size_bytes[layer_name] = layer_kv_cache_spec.page_size_bytes
unpadded_page_size_bytes[layer_name] = replace(
layer_kv_cache_spec, page_size_padded=None
).page_size_bytes

else:
raise NotImplementedError

Collaborator


is there any way we can make this if-elif-else simpler, e.g. by factoring out the common assignments at the end?

Collaborator Author


All cases need to assign to tensors_per_block, page_size_bytes and unpadded_page_size_bytes.
To move the assignments out, I would need to introduce a local variable for each of these three dictionaries.
I don't see that simplifying things.
But maybe I missed your point...

@orozery
Collaborator Author

orozery commented Mar 24, 2026

Could you elaborate on why we would need a CanonicalKVCacheTensor here and how is that abstraction making life easier for you in code/transfer?

The offloading connector supports pluggable backends (e.g. a CPU backend, and a file-system backend in the future).
We want the backend interface to be as simple as possible, and to solve complexities once for all backends, instead of letting each backend handle them by itself.

One such complexity is registering the GPU KV caches.
The backend needs to deal with things like split-K/V layout (flash-attention), MLA layout, mamba packed state, kernel block size, and more.

To avoid all of these complexities, we define this class:

class CanonicalKVCaches:
    """
    Canonicalized block-level representation of the KV caches.

    Composed of:
        - Unique list of KV cache data tensors,
          each with shape (num_blocks, page_size_in_bytes) and int8 dtype.
        - Per-group data references of the tensors.
          i.e. how each KV cache group maps to the tensors.
    """

This allows the backend to easily use the KV caches, without having to deal with all of the above complexities.

Before this PR, we handled SOME (but not all, e.g. mamba) of these complexities inside the CPU backend (cpu_gpu.py).
You can see that as a result of this PR, many lines of code are removed from cpu_gpu.py, as it is now given CanonicalKVCaches instead of the previous kv_caches: dict[str, torch.Tensor].

So basically, the offloading connector takes responsibility for "translating" the complex-layout kv_caches into a simple, canonical, easy to work with CanonicalKVCaches.
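As a rough sketch of the idea described above (the field names and the CanonicalTensorRef helper are assumptions for illustration; the real vLLM class differs, and the real canonical tensors are int8 torch tensors rather than the bytearrays stubbed in here):

```python
# Minimal sketch of the CanonicalKVCaches concept. Only the class name
# CanonicalKVCaches comes from the PR; everything else is a stand-in.
from dataclasses import dataclass

@dataclass(frozen=True)
class CanonicalTensorRef:
    """Points one KV cache group at a slice of one canonical tensor."""
    tensor_idx: int       # index into CanonicalKVCaches.tensors
    page_size_bytes: int  # bytes of this group's data per block

@dataclass
class CanonicalKVCaches:
    # Unique canonical tensors, each conceptually int8 with shape
    # (num_blocks, page_size_in_bytes); stubbed here as raw byte buffers.
    tensors: list[bytearray]
    # Per-group references: group_refs[g] tells group g which tensor
    # holds its data and how many bytes each block occupies.
    group_refs: list[CanonicalTensorRef]

# Two groups sharing one buffer (tensors may be shared by multiple layers):
buf = bytearray(4 * 1024)  # 4 blocks of 1024 bytes each
caches = CanonicalKVCaches(
    tensors=[buf],
    group_refs=[CanonicalTensorRef(0, 1024), CanonicalTensorRef(0, 1024)],
)
print(len(caches.tensors), caches.group_refs[0].page_size_bytes)  # 1 1024
```

A backend that receives such an object only iterates over uniform (num_blocks, page_size) byte buffers, regardless of the attention layout that produced them.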

Collaborator

@NickLucche NickLucche left a comment


All cases need to assign to tensors_per_block, page_size_bytes and unpadded_page_size_bytes.

Yes, that's the pattern I would like to factor out, possibly the canonical torch.tensor/raw creation in particular, which is quite verbose.

Anyway, I am unblocking since this is minor; it could maybe simply be wrapped in a util or a constructor method on CanonicalKVCaches, in order to keep the core logic in the worker file as lean as possible.

I'll leave it to you to shape as you see fit for best maintainability.

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
@NickLucche NickLucche closed this Mar 25, 2026
@NickLucche
Collaborator

NickLucche commented Mar 25, 2026

@orozery let's check CI

sorry misclicked on closing somehow >.<

@NickLucche NickLucche reopened this Mar 25, 2026
@orozery orozery force-pushed the kv-offload-register-hybrid-kv-caches branch from 89f5cca to 236714e Compare March 26, 2026 05:07
@orozery
Collaborator Author

orozery commented Mar 26, 2026

@NickLucche there was an issue with unit tests (incl. the nixl connector's) that were using set_kv_cache_layout and were affecting each other, because get_kv_cache_layout uses @functools.lru_cache.
I fixed it here for both of our tests.
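The interference pattern described here is a generic @functools.lru_cache pitfall: once the first test populates the cache, later tests that change the underlying setting still see the stale value. A minimal sketch, with stand-in definitions for get_kv_cache_layout/set_kv_cache_layout (only the function names match vLLM; the bodies are assumptions for illustration):

```python
import functools

_layout = "NHD"  # stand-in for the env/config the real function reads

@functools.lru_cache
def get_kv_cache_layout() -> str:
    return _layout

def set_kv_cache_layout(layout: str) -> None:
    global _layout
    _layout = layout

first = get_kv_cache_layout()      # "NHD", now cached
set_kv_cache_layout("HND")
stale = get_kv_cache_layout()      # still "NHD": lru_cache ignores the change
get_kv_cache_layout.cache_clear()  # the usual fix: clear between tests,
fresh = get_kv_cache_layout()      # e.g. in a fixture -> now "HND"
print(first, stale, fresh)
```

Calling `cache_clear()` in test setup/teardown is the standard way to isolate tests that share such a cached, environment-dependent function.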

@orozery orozery force-pushed the kv-offload-register-hybrid-kv-caches branch from 236714e to 112c22b Compare March 26, 2026 07:08
@mergify
Contributor

mergify Bot commented Mar 26, 2026

Hi @orozery, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

This commit extends the offloading connector register_kv_caches function
to support KV caches used in hybrid models.
We define a new CanonicalKVCaches class which captures:
1. The unique set of KV cache tensors (tensors may be shared by multiple layers)
2. Mapping each group to its relevant KV cache data (given by a tensor pointer + page size).
The canonical tensors each have dtype int8 and shape (num_blocks, page_size).
This commit also splits the offloading connector unit tests into multiple files.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
@orozery orozery force-pushed the kv-offload-register-hybrid-kv-caches branch from 112c22b to 53eca8a Compare March 26, 2026 09:49
@rarepepi

excited for this! I think it's very needed for my deepseekv3.2 nvfp4 setup @orozery ty 🙏

@dannyboycrypt0

I believe merging this PR would solve one of the issues we're working on right now as well.

@orozery orozery merged commit 7cc302d into vllm-project:main Mar 27, 2026
69 checks passed
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…llm-project#37853)

Signed-off-by: Or Ozeri <oro@il.ibm.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
…llm-project#37853)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: Rishi Puri <riship@nvidia.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Apr 8, 2026
### What this PR does / why we need it?
Main2main upgrade vllm to 0330
fix breaks:
1. vllm-project/vllm#37728 adds a clear_row method for BlockTable
2. vllm-project/vllm#37975 adapt to the GatedDeltaNetAttention refactor
3. vllm-project/vllm#37698 update maybe_update_config in vllm_ascend/quantization/modelslim_config.py to adapt to this PR's change
4. vllm-project/vllm#37880 this PR adds the ability to set different moe backends for the draft and target models; we should overwrite it in the draft proposer
5. vllm-project/vllm#37853 for now, just skip the test_cpu_offloading.py test case until this feature has been adapted

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

CI

- vLLM version: v0.18.0
- vLLM main:
vllm-project/vllm@29e4870

---------

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: Claude Code <claude@anthropic.com>
Co-authored-by: Claude Code <noreply@anthropic.com>
Co-authored-by: wxsIcey <1790571317@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1

4 participants